# Copyright 2017 PerfKitBenchmarker Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Relational DB class.

This is the base implementation of all relational db.
"""

import abc
import datetime
import logging
from absl import flags
from perfkitbenchmarker import background_tasks
from perfkitbenchmarker import data
from perfkitbenchmarker import errors
from perfkitbenchmarker import os_types
from perfkitbenchmarker import resource
from perfkitbenchmarker import sample
from perfkitbenchmarker import sql_engine_utils

# TODO(chunla): Move IAAS flag to file
# --- Engine selection -------------------------------------------------------
# The help text mentions mysql and postgres; DEFAULT_PORTS in this module also
# has an entry for SQL Server.
flags.DEFINE_string(
    'db_engine', None, 'Managed database flavor to use (mysql, postgres)'
)
flags.DEFINE_string(
    'db_engine_version',
    None,
    'Version of the database flavor selected, e.g. 5.7',
)
# --- Database name and credentials ------------------------------------------
# All three flags default to None here. The fallback values described in the
# help text are not computed in this module.
# NOTE(review): presumably they are filled in where the spec is built -- confirm.
flags.DEFINE_string(
    'database_name',
    None,
    'Name of the database to create. Defaults to pkb-db-[run-uri]',
)
flags.DEFINE_string(
    'database_username',
    None,
    'Database username. Defaults to pkb-db-user-[run-uri]',
)
flags.DEFINE_string(
    'database_password',
    None,
    'Database password. Defaults to a random 10-character alpha-numeric string',
)
# --- High availability ------------------------------------------------------
flags.DEFINE_boolean(
    'db_high_availability',
    False,
    'Specifies if the database should be high availability',
)
# High-availability topology. The help text lists AOAG, FCIS2D and FCIMW, but
# the flag is a free-form string, so the value is not validated here.
flags.DEFINE_string(
    'db_high_availability_type',
    None,
    'Specifies high availability type. (AOAG, FCIS2D, FCIMW)',
)
# --- Backups ----------------------------------------------------------------
flags.DEFINE_boolean(
    'db_backup_enabled', True, 'Whether or not to enable automated backups'
)
# --- Placement --------------------------------------------------------------
# db_zone and db_replica_zones are list flags (comma-separated values), even
# though the db_zone help text is phrased for a single zone.
flags.DEFINE_list(
    'db_zone',
    None,
    'zone or region to launch the database in. '
    "Defaults to the client vm's zone.",
)
flags.DEFINE_list(
    'db_replica_zones',
    [],
    'zones to launch the database replicas in. ',
)
flags.DEFINE_string(
    'client_vm_zone', None, 'zone or region to launch the client in. '
)

flags.DEFINE_integer(
    'client_vm_count', None, 'Number of client vms to provision '
)
# --- Database sizing --------------------------------------------------------
# Either a machine type or a cpus/memory pair can describe the database;
# GetResourceMetadata reports whichever the spec provides.
flags.DEFINE_string('db_machine_type', None, 'Machine type of the database.')
flags.DEFINE_integer('db_cpus', None, 'Number of Cpus in the database.')
flags.DEFINE_string(
    'db_memory',
    None,
    'Amount of Memory in the database.  Uses the same format '
    'string as custom machine memory type.',
)
# --- Database disks ---------------------------------------------------------
flags.DEFINE_integer('db_disk_size', None, 'Size of the database disk in GB.')
flags.DEFINE_integer(
    'db_num_striped_disks',
    None,
    'The number of data disks to stripe together to form one.',
)
flags.DEFINE_string('db_disk_type', None, 'Disk type of the database.')
flags.DEFINE_integer(
    'db_disk_iops',
    None,
    'Disk IOPs to provision for database disks, if provisioning is applicable'
    ' or required. IOPs applies to each disk.',
)
# The help text is built from adjacent string literals, which Python joins
# with no separator, so the continuation fragment carries its own leading
# space.
flags.DEFINE_integer(
    'db_disk_throughput',
    None,
    'Disk throughput to provision for database disks, if provisioning is'
    ' applicable or required. Throughput applies to each disk.',
)

# --- Azure-specific sizing (per the help text) -------------------------------
# GetResourceMetadata reports a tier/compute_units pair when the database spec
# has no machine type and no cpus/memory pair.
# NOTE(review): presumably these two flags feed that pair -- confirm where the
# spec is built.
flags.DEFINE_integer(
    'managed_db_azure_compute_units', None, 'Number of Dtus in the database.'
)
flags.DEFINE_string(
    'managed_db_tier', None, 'Tier in azure. (Basic, Standard, Premium).'
)
# OS type for the database server VM; the client VM has its own
# client_vm_os_type flag.
flags.DEFINE_string('server_vm_os_type', None, 'OS type of the server vm.')
# --- Client / server VM configuration ---------------------------------------
flags.DEFINE_string('client_vm_os_type', None, 'OS type of the client vm.')
# Minimum CPU platform for the server and client VMs (GCP, per the flag names).
flags.DEFINE_string(
    'server_gcp_min_cpu_platform', None, 'Cpu platform of the server vm.'
)
flags.DEFINE_string(
    'client_gcp_min_cpu_platform', None, 'CPU platform of the client vm.'
)
# Client sizing: either a machine type or a cpus/memory pair;
# GetResourceMetadata raises if the client spec provides neither.
flags.DEFINE_string(
    'client_vm_machine_type', None, 'Machine type of the client vm.'
)
flags.DEFINE_integer('client_vm_cpus', None, 'Number of Cpus in the client vm.')
flags.DEFINE_string(
    'client_vm_memory',
    None,
    'Amount of Memory in the vm.  Uses the same format '
    'string as custom machine memory type.',
)
# Client disk settings.
flags.DEFINE_integer(
    'client_vm_disk_size', None, 'Size of the client vm disk in GB.'
)
flags.DEFINE_string('client_vm_disk_type', None, 'Disk type of the client vm.')
flags.DEFINE_integer(
    'client_vm_disk_iops',
    None,
    'Disk iops of the database on AWS for client vm.',
)
# --- Managed vs. VM-hosted database -----------------------------------------
# VmsToBoot reads this flag to decide whether the 'servers' and
# 'servers_replicas' VM groups are booted.
# NOTE(review): the help text mentions MySql only, while VmsToBoot applies the
# flag regardless of engine.
flags.DEFINE_boolean(
    'use_managed_db',
    True,
    'If true, uses the managed MySql '
    'service for the requested cloud provider. If false, uses '
    'MySql installed on a VM.',
)
# Engine flags as a comma-separated list of name=value items (see the example
# in the help text). A non-empty value is reported by GetResourceMetadata.
flags.DEFINE_list(
    'db_flags',
    '',
    'Flags to apply to the managed relational database '
    "on the cloud that's being used. Example: "
    'binlog_cache_size=4096,innodb_log_buffer_size=4294967295',
)
# --- MySQL tuning -----------------------------------------------------------
# The buffer pool is sized either absolutely (GB) or as a ratio of VM memory;
# per the help text the absolute size wins when both are given.
flags.DEFINE_integer(
    'innodb_buffer_pool_size',
    None,
    'Size of the innodb buffer pool size in GB. '
    'If unset, innodb_buffer_pool_ratio is used.',
)

# The ratio is constrained to the range [0, 1] by the flag bounds.
flags.DEFINE_float(
    'innodb_buffer_pool_ratio',
    0.25,
    'Ratio of the innodb buffer pool size to VM memory. '
    'Ignored if innodb_buffer_pool_size is set.',
    lower_bound=0,
    upper_bound=1,
)

flags.DEFINE_bool(
    'mysql_bin_log', False, 'Flag to turn binary logging on. Defaults to False'
)
flags.DEFINE_integer(
    'innodb_log_file_size',
    1000,
    'Size of the log file in MB. Defaults to 1000M.',
)
# --- PostgreSQL tuning ------------------------------------------------------
# Same absolute-size-or-ratio scheme as the innodb buffer pool flags above.
flags.DEFINE_integer(
    'postgres_shared_buffer_size',
    None,
    'Size of the shared buffer size in GB. '
    'If unset, postgres_shared_buffer_ratio is used.',
)

# The ratio is constrained to the range [0, 1] by the flag bounds.
flags.DEFINE_float(
    'postgres_shared_buffer_ratio',
    0.25,
    'Ratio of the shared buffer size to VM memory. '
    'Ignored if postgres_shared_buffer_size is set.',
    lower_bound=0,
    upper_bound=1,
)

# --- Flags whose DEFINE_* return value is kept in a module-level constant ----
# None of these constants is read elsewhere in this module.
OPTIMIZE_DB_SYSCTL_CONFIG = flags.DEFINE_bool(
    'optimize_db_sysctl_config',
    True,
    'Flag to run sysctl optimization for IAAS relational database.',
)

SERVER_GCE_NUM_LOCAL_SSDS = flags.DEFINE_integer(
    'server_gce_num_local_ssds',
    0,
    'The number of ssds that should be added to the Server. Note '
    'that the flag only applies for vms that can have a variable '
    'number of local SSDs.',
)
# Enum flag: only 'SCSI' or 'NVME' are accepted.
SERVER_GCE_SSD_INTERFACE = flags.DEFINE_enum(
    'server_gce_ssd_interface',
    'SCSI',
    ['SCSI', 'NVME'],
    'The ssd interface for GCE local SSD.',
)

ENABLE_DATA_CACHE = flags.DEFINE_bool(
    'gcp_db_enable_data_cache', False, 'Whether to enable data cache.'
)
# strftime/strptime-style timestamp pattern; the trailing 'Z' is a literal
# character, not a format directive.
METRICS_TIME_FORMAT = '%Y-%m-%dT%H:%M:%SZ'
# NOTE(review): neither metrics constant is referenced in this module;
# presumably they are consumed by CollectMetrics implementations in
# subclasses -- confirm.
METRICS_COLLECTION_DELAY_SECONDS = 165


# Global flag registry; read lazily by functions in this module.
FLAGS = flags.FLAGS


# Default listening ports, one per supported engine type.
DEFAULT_MYSQL_PORT = 3306
DEFAULT_POSTGRES_PORT = 5432
DEFAULT_SQLSERVER_PORT = 1433

# Engine type -> default port. BaseRelationalDb.GetDefaultPort looks the engine
# type up here and raises NotImplementedError for engines that are absent.
DEFAULT_PORTS = {
    sql_engine_utils.MYSQL: DEFAULT_MYSQL_PORT,
    sql_engine_utils.POSTGRES: DEFAULT_POSTGRES_PORT,
    sql_engine_utils.SQLSERVER: DEFAULT_SQLSERVER_PORT,
}

# SQL statements issued by ClearWaitStats and QueryIOStats; both methods only
# run them when the engine type is SQL Server.
CLEAR_WAIT_STATS_SQL = "DBCC SQLPERF ('sys.dm_os_wait_stats', CLEAR);"
CAPTURE_IO_STATS_SQL = (
    'select DB_NAME(database_id) as DBName,* from '
    'sys.dm_io_virtual_file_stats(NULL,NULL)'
)


class RelationalDbPropertyNotSetError(Exception):
  """Raised when a required relational database property has not been set.

  Used for a missing client VM, a missing port, and missing machine sizing
  information in the resource metadata.
  """


class RelationalDbEngineNotFoundError(Exception):
  """Error type for a database engine that cannot be found.

  Not raised anywhere in this module; it is defined here so that other
  modules can raise and catch it.
  """


class UnsupportedError(Exception):
  """Error type for an unsupported operation or configuration.

  Not raised anywhere in this module; it is defined here so that other
  modules can raise and catch it.
  """


def GetRelationalDbClass(cloud, is_managed_db, engine):
  """Looks up the RelationalDb class registered for the given attributes.

  An engine-specific class is preferred. When no class is registered for the
  (cloud, is_managed_db, engine) combination, the lookup is repeated without
  the engine.

  Args:
    cloud: name of cloud to get the class for
    is_managed_db: whether the database is a managed service (database as a
      service) or self managed
    engine: database engine type

  Returns:
    BaseRelationalDb class with the cloud attribute of 'cloud'.
  """
  base_attrs = dict(CLOUD=cloud, IS_MANAGED=is_managed_db)
  try:
    return resource.GetResourceClass(
        BaseRelationalDb, **base_attrs, ENGINE=engine
    )
  except errors.Resource.SubclassNotFoundError:
    # No engine-specific class registered; use the engine-agnostic lookup
    # below. It runs outside this handler, as before.
    pass

  return resource.GetResourceClass(BaseRelationalDb, **base_attrs)


def VmsToBoot(vm_groups):
  """Filters the VM groups down to those that should be booted.

  The 'clients', 'default' and 'controller' groups are always kept. The
  'servers' and 'servers_replicas' groups are kept only when the
  use_managed_db flag is false.

  Args:
    vm_groups: dict mapping VM group name to its spec.

  Returns:
    A new dict with the selected (name, spec) pairs, in the original order.
  """
  # TODO(user): Enable replications.
  always_booted = ('clients', 'default', 'controller')
  server_groups = ('servers', 'servers_replicas')
  selected = {}
  for group_name, group_spec in vm_groups.items():
    # The flag is only consulted for names outside the always-booted set.
    if group_name in always_booted or (
        not FLAGS.use_managed_db and group_name in server_groups
    ):
      selected[group_name] = group_spec
  return selected


class BaseRelationalDb(resource.BaseResource):
  """Object representing a relational database Service.

  Base class for relational database resources. GetRelationalDbClass looks up
  concrete classes by their CLOUD, IS_MANAGED and (optionally) ENGINE
  attributes.
  """

  # Copied to `is_managed_db` on every instance by __init__.
  IS_MANAGED = True
  RESOURCE_TYPE = 'BaseRelationalDb'
  # NOTE(review): presumably the class attributes the resource module uses to
  # register and look up subclasses -- confirm against resource.BaseResource.
  REQUIRED_ATTRS = ['CLOUD', 'IS_MANAGED']

  def __init__(self, relational_db_spec):
    """Sets up the relational database object from its spec.

    The freeze/restore settings of the spec are forwarded to the base
    resource. The engine type and default port are derived from the spec's
    engine.

    Args:
      relational_db_spec: spec of the managed database.
    """
    spec = relational_db_spec
    super().__init__(
        enable_freeze_restore=spec.enable_freeze_restore,
        create_on_restore_error=spec.create_on_restore_error,
        delete_on_freeze_error=spec.delete_on_freeze_error,
    )
    self.spec = spec
    self.engine = spec.engine
    self.engine_type = sql_engine_utils.GetDbEngineType(self.engine)
    self.instance_id = 'pkb-db-instance-' + FLAGS.run_uri
    # Assigned through the `port` property setter, which coerces to int.
    self.port = self.GetDefaultPort()
    self.is_managed_db = self.IS_MANAGED
    # Both endpoints start out empty; this base class never assigns them
    # afterwards (_SetEndpoint is a no-op here).
    self.endpoint = self.replica_endpoint = ''
    # Filled in by SetVms.
    self.client_vms = []

  @property
  def client_vm(self):
    """Client VM which will drive the database test.

    This is required by subclasses to perform client-vm
    network-specific tasks, such as getting information about
    the VPC, IP address, etc.

    Raises:
      RelationalDbPropertyNotSetError: if the client_vm is missing.

    Returns:
      The client_vm.
    """
    if not hasattr(self, '_client_vm'):
      raise RelationalDbPropertyNotSetError('client_vm is not set')
    return self._client_vm

  # TODO(user): add support for multiple client VMs
  @client_vm.setter
  def client_vm(self, client_vm):
    """Stores the client VM that the `client_vm` property returns."""
    self._client_vm = client_vm

  def _GetDbConnectionProperties(
      self,
  ) -> sql_engine_utils.DbConnectionProperties:
    """Returns connection properties that target the primary endpoint.

    The properties combine the spec's engine, engine version and credentials
    with this instance's endpoint and port.
    """
    spec = self.spec
    connection_args = (
        spec.engine,
        spec.engine_version,
        self.endpoint,
        self.port,
        spec.database_username,
        spec.database_password,
    )
    return sql_engine_utils.DbConnectionProperties(*connection_args)

  # TODO(user): Deprecate in favor of client_vms_query_tools
  @property
  def client_vm_query_tools(self):
    return self.client_vms_query_tools[0]

  @property
  def client_vms_query_tools(self) -> list[sql_engine_utils.ISQLQueryTools]:
    connection_properties = self._GetDbConnectionProperties()
    tools = [
        sql_engine_utils.GetQueryToolsByEngine(vm, connection_properties)
        for vm in self.client_vms
    ]
    return [t for t in tools if t is not None]

  @property
  def client_vm_query_tools_for_replica(self):
    """Query tools to make custom queries on replica."""
    if not self.replica_endpoint:
      raise ValueError('Database does not have replica.')
    connection_properties = sql_engine_utils.DbConnectionProperties(
        self.spec.engine,
        self.spec.engine_version,
        self.replica_endpoint,
        self.port,
        self.spec.database_username,
        self.spec.database_password,
    )
    return sql_engine_utils.GetQueryToolsByEngine(
        self.client_vm, connection_properties
    )

  def SetVms(self, vm_groups):
    self.client_vms = vm_groups[
        'clients' if 'clients' in vm_groups else 'default'
    ]
    # TODO(user): Remove this after moving to multiple client VMs.
    self.client_vm = self.client_vms[0]

  def ClearWaitStats(self) -> None:
    """Clears the wait statistics; a no-op unless the engine is SQL Server."""
    if self.engine_type != sql_engine_utils.SQLSERVER:
      return
    logging.info('Clearing wait stats')
    query_tools = self.client_vm_query_tools
    query_tools.IssueSqlCommand(CLEAR_WAIT_STATS_SQL)

  def QueryIOStats(self) -> tuple[str, str]:
    """Runs the IO stats query on SQL Server.

    Returns:
      The result of issuing CAPTURE_IO_STATS_SQL through the client VM query
      tools, or a pair of empty strings for any other engine.
    """
    if self.engine_type != sql_engine_utils.SQLSERVER:
      return ('', '')
    logging.info('Querying IO stats')
    query_tools = self.client_vm_query_tools
    return query_tools.IssueSqlCommand(CAPTURE_IO_STATS_SQL)

  def QueryWaitStats(self) -> tuple[str, str]:
    """Runs the wait stats query on SQL Server.

    The query text is read from the 'capture_wait_stats.sql' data resource.

    Returns:
      The result of issuing that query through the client VM query tools, or
      a pair of empty strings for any other engine.
    """
    if self.engine_type != sql_engine_utils.SQLSERVER:
      return ('', '')
    logging.info('Querying wait stats')
    with open(data.ResourcePath('capture_wait_stats.sql'), 'r') as sql_file:
      wait_stats_sql = sql_file.read()
      return self.client_vm_query_tools.IssueSqlCommand(wait_stats_sql)

  @property
  def port(self):
    """Port (int) on which the database server is listening."""
    if not hasattr(self, '_port'):
      raise RelationalDbPropertyNotSetError('port not set')
    return self._port

  @port.setter
  def port(self, port):
    self._port = int(port)

  def GetResourceMetadata(self):
    """Builds the metadata dictionary for this database and its client VM.

    Child classes can extend this if needed.

    Returns:
      A dict with the database zone, disk, engine, availability and backup
      settings, and the client VM zone and disk settings. It also holds the
      sizing of both machines: a machine type or a cpus/memory pair, and for
      the database alternatively a tier/compute_units pair. The keys
      db_flags, db_tier, disk_iops and disk_throughput_mb are added when the
      corresponding values are available.

    Raises:
       RelationalDbPropertyNotSetError: if no sizing information is available
         for the database or for the client VM.
    """
    spec = self.spec

    # Database-side entries. Keys are inserted in a fixed order.
    db_spec = spec.db_spec
    metadata = {'zone': db_spec.zone}
    db_disk_spec = spec.db_disk_spec
    metadata['disk_type'] = db_disk_spec.disk_type
    metadata['disk_size'] = db_disk_spec.disk_size
    metadata['engine'] = spec.engine
    metadata['high_availability'] = spec.high_availability
    metadata['backup_enabled'] = spec.backup_enabled
    metadata['engine_version'] = spec.engine_version

    # Client-side entries always come from the 'clients' VM group.
    clients_group = spec.vm_groups['clients']
    client_vm_spec = clients_group.vm_spec
    metadata['client_vm_zone'] = client_vm_spec.zone
    metadata['use_managed_db'] = self.is_managed_db
    metadata['instance_id'] = self.instance_id
    client_disk_spec = clients_group.disk_spec
    metadata['client_vm_disk_type'] = client_disk_spec.disk_type
    metadata['client_vm_disk_size'] = client_disk_spec.disk_size

    # Database sizing: a non-empty machine type wins, then cpus/memory, then
    # tier/compute_units. The pairs only need to exist as attributes.
    if getattr(db_spec, 'machine_type', None):
      metadata['machine_type'] = db_spec.machine_type
    elif hasattr(db_spec, 'cpus') and hasattr(db_spec, 'memory'):
      metadata['cpus'] = db_spec.cpus
      metadata['memory'] = db_spec.memory
    elif hasattr(db_spec, 'tier') and hasattr(db_spec, 'compute_units'):
      metadata['tier'] = db_spec.tier
      metadata['compute_units'] = db_spec.compute_units
    else:
      raise RelationalDbPropertyNotSetError(
          'Machine type of the database must be set.'
      )

    # Client sizing: a non-empty machine type wins, then cpus/memory.
    if getattr(client_vm_spec, 'machine_type', None):
      metadata['client_vm_machine_type'] = client_vm_spec.machine_type
    elif hasattr(client_vm_spec, 'cpus') and hasattr(client_vm_spec, 'memory'):
      metadata['client_vm_cpus'] = client_vm_spec.cpus
      metadata['client_vm_memory'] = client_vm_spec.memory
    else:
      raise RelationalDbPropertyNotSetError(
          'Machine type of the client VM must be set.'
      )

    # Optional entries.
    db_flags = FLAGS.db_flags
    if db_flags:
      metadata['db_flags'] = db_flags

    if hasattr(spec, 'db_tier'):
      metadata['db_tier'] = spec.db_tier

    if hasattr(db_disk_spec, 'provisioned_iops'):
      metadata['disk_iops'] = db_disk_spec.provisioned_iops
    if hasattr(db_disk_spec, 'provisioned_throughput'):
      metadata['disk_throughput_mb'] = db_disk_spec.provisioned_throughput

    return metadata

  @abc.abstractmethod
  def GetDefaultEngineVersion(self, engine):
    """Return the default version (for PKB) for the given database engine.

    Abstract: this base class provides no implementation.

    Args:
      engine: name of the database engine

    Returns:
      default version as a string for the given engine.
    """

  def _SetEndpoint(self):
    """Set the DB endpoint for this instance during _PostCreate."""
    pass

  def _PostCreate(self):
    """Finishes setup after creation.

    Steps, in order: apply the DB flags when the spec has any, set the
    endpoint, install packages for every client's query tools through
    background_tasks.RunThreaded, then make a best-effort tcptraceroute from
    the client VM to the database.
    """
    if self.spec.db_flags:
      self._ApplyDbFlags()
    self._SetEndpoint()

    def _InstallClientPackages(client_query_tools):
      return client_query_tools.InstallPackages()

    background_tasks.RunThreaded(
        _InstallClientPackages, self.client_vms_query_tools
    )

    # Add a traceroute command to the client VM to ensure that the database is
    # accessible. This also informs the baseline latency of the network.
    # It only runs for Linux clients and when an endpoint is known.
    if (
        self.client_vm.OS_TYPE not in os_types.LINUX_OS_TYPES
        or not self.endpoint
    ):
      return
    # Failures of either command are ignored.
    install_command = 'sudo apt-get install -y tcptraceroute'
    self.client_vm.RemoteCommand(install_command, ignore_failure=True)
    trace_command = f'tcptraceroute {self.endpoint} {self.port}'
    self.client_vm.RemoteCommand(trace_command, ignore_failure=True)

  def UpdateCapacityForLoad(self) -> None:
    """Updates infrastructure to the correct capacity for loading."""
    pass

  def UpdateCapacityForRun(self) -> None:
    """Updates infrastructure to the correct capacity for running."""
    pass

  def _ApplyDbFlags(self):
    """Apply Flags on the database."""
    raise NotImplementedError(
        'Managed Db flags is not supported for %s' % type(self).__name__
    )

  def GetDefaultPort(self):
    """Returns default port for the db engine from the spec.

    Raises:
      NotImplementedError: if DEFAULT_PORTS has no entry for the engine type.
    """
    engine_type = sql_engine_utils.GetDbEngineType(self.spec.engine)
    if engine_type in DEFAULT_PORTS:
      return DEFAULT_PORTS[engine_type]
    raise NotImplementedError(
        f'Default port not specified for engine {engine_type}'
    )

  def CreateDatabase(self, database_name: str) -> tuple[str, str]:
    """Creates the database."""
    return self.client_vms_query_tools[0].CreateDatabase(database_name)

  def DeleteDatabase(self, database_name: str) -> tuple[str, str]:
    """Deletes the database."""
    return self.client_vms_query_tools[0].DeleteDatabase(database_name)

  def Failover(self):
    """Fail over the database.  Throws exception if not high available."""
    if not self.spec.high_availability:
      raise RuntimeError(
          "Attempt to fail over a database that isn't marked as high available"
      )
    self._FailoverHA()

  def _FailoverHA(self):
    """Fail over from master to replica."""
    raise NotImplementedError('Failover is not implemented.')

  def RestartDatabase(self):
    """Restarts all the database services in the benchmark.

    Args: None

    Returns: none
    """
    raise NotImplementedError('Restart database is not implemented.')

  def CollectMetrics(
      self, start_time: datetime.datetime, end_time: datetime.datetime
  ) -> list[sample.Sample]:
    """Collects and returns performance metrics after the run phase.

    This method is optional for subclasses to implement. Subclasses should
    query cloud-specific monitoring APIs for metrics within the provided time
    range and return them as a list of sample.Sample objects.

    Args:
      start_time: The start time of the run phase.
      end_time: The end time of the run phase.

    Returns:
      A list of sample.Sample instances containing the collected metrics.
    """
    del start_time, end_time
    return []
